#include "duckdb.hpp"
#ifndef DUCKDB_AMALGAMATION
#include "duckdb/execution/operator/aggregate/physical_window.hpp"
#include "duckdb/planner/expression.hpp"
#include "duckdb/planner/expression/bound_window_expression.hpp"
#include "duckdb/planner/expression/bound_constant_expression.hpp"
#include "duckdb/planner/expression/bound_reference_expression.hpp"
#include "duckdb/parallel/thread_context.hpp"
#include "duckdb/function/aggregate/distributive_functions.hpp"
#endif

using namespace duckdb;

int main() {
	DuckDB db(nullptr);
	Connection con(db);

	// in this example we show how to execute the following window function/query fragment
	// select a,b, sum(a) over (partition by b order by a rows between unbounded preceding and 0 following) from
	// integers;

	// the input types are the types we pass INTO the window function (i.e. the input to the window functions)
	// in our example this is {a, b}
	vector<LogicalType> input_types {LogicalType::BIGINT, LogicalType::BIGINT};

	// construct the expressions
	// expressions can either be built-in specific window functions (e.g. row_number, rank, ntile, etc) or combinable
	// aggregate functions sum(a) over (partition by b order by a rows between unbounded preceding and 0 following)
	auto aggregate = SumFun::GetSumAggregate(PhysicalType::INT64);
	auto return_type = aggregate.return_type;
	auto sum = make_uniq<BoundWindowExpression>(ExpressionType::WINDOW_AGGREGATE, return_type,
	                                            make_uniq<AggregateFunction>(aggregate), nullptr);
	sum->start = WindowBoundary::UNBOUNDED_PRECEDING; // unbounded preceding
	sum->end = WindowBoundary::EXPR_FOLLOWING;        // following 0 (0 is an expression)
	sum->end_expr = make_uniq<BoundConstantExpression>(Value::BIGINT(0));
	// a is child 0, b is child 1
	// normally these reference expressions are generated by our binder, but we can manually construct them
	sum->children.push_back(make_uniq<BoundReferenceExpression>(input_types[0], 0));   // sum(a)
	sum->partitions.push_back(make_uniq<BoundReferenceExpression>(input_types[1], 1)); // partition by b
	// order by a
	BoundOrderByNode node(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST,
	                      make_uniq<BoundReferenceExpression>(input_types[0], 0) // a
	);
	sum->orders.push_back(std::move(node));

	// the return types are (1) the input types, and (2) the return types of the computed window functions
	// the window function also returns all input columns rather than only the result columns of the window function
	// this is because the window functions are generally used alongside other projections, e.g.
	// -> select a, b, row_number() over () from tbl
	// in this case the window function will receive [a, b] as input, and output [a, b, row_num] as output
	vector<LogicalType> result_types {input_types[0], input_types[1], sum->return_type};
	vector<duckdb::unique_ptr<Expression>> expressions;
	expressions.push_back(std::move(sum));
	// construct the window operator
	auto window = make_uniq<PhysicalWindow>(result_types, std::move(expressions), 0);

	// now we can run the window function
	// first set up some contexts
	auto &client_context = *con.context;
	ThreadContext thread_context(client_context);
	ExecutionContext econtext(client_context, thread_context);

	// global state needs to be shared amongst all threads
	// local_state is thread-local - every thread should have their own
	auto global_state = window->GetGlobalSinkState(client_context);
	auto local_state = window->GetLocalSinkState(econtext);

	// the actual computation
	// we need to insert some data...
	// we generate this data: create table integers as select i % 15 a, i % 5 b from generate_series(2048) tbl(i);

	DataChunk chunk;
	chunk.Initialize(input_types);
	const idx_t vector_count = 2;
	int64_t current_val = 0;
	for (idx_t i = 0; i < vector_count; i++) {
		// set up the chunk to insert into the window function
		for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; i++) {
			chunk.SetValue(0, i, Value::BIGINT(current_val % 15));
			chunk.SetValue(1, i, Value::BIGINT(current_val % 5));
			current_val++;
		}
		chunk.SetCardinality(STANDARD_VECTOR_SIZE);

		// actually feed the data into the window function
		window->Sink(econtext, *global_state, *local_state, chunk);
	}

	// now combine
	// this should happen once per thread
	window->Combine(econtext, *global_state, *local_state);

	// now finalize
	// this should happen once in total, after every thread has been combined
	window->FinalizeInternal(client_context, std::move(global_state));

	// after the window function is finalized we can pull the result from it using the GetChunk method
	DataChunk result;
	result.Initialize(result_types);
	auto global_source_state = window->GetGlobalSourceState(client_context);
	auto local_source_state = window->GetLocalSourceState(econtext, *global_source_state);
	while (true) {
		window->GetData(econtext, result, *global_source_state, *local_source_state);
		if (result.size() == 0) {
			break;
		}
		result.Print();
	}
}
